{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 0
   },
   "source": [
    "# Machine Translation and the Dataset\n",
    ":label:`sec_machine_translation`\n",
    "\n",
    "We have used RNNs to design language models,\n",
    "which are key to natural language processing.\n",
    "Another flagship benchmark is *machine translation*,\n",
    "a central problem domain for *sequence transduction* models\n",
    "that transform input sequences into output sequences.\n",
    "Playing a crucial role in various modern AI applications,\n",
    "sequence transduction models will form the focus of the remainder of this chapter\n",
    "and :numref:`chap_attention`.\n",
    "To this end,\n",
    "this section introduces the machine translation problem\n",
    "and its dataset that will be used later.\n",
    "\n",
    "\n",
    "*Machine translation* refers to the\n",
    "automatic translation of a sequence\n",
    "from one language to another.\n",
    "In fact, this field\n",
    "may date back to 1940s\n",
    "soon after digital computers were invented,\n",
    "especially by considering the use of computers\n",
    "for cracking language codes in World War II.\n",
    "For decades,\n",
    "statistical approaches\n",
    "had been dominant in this field :cite:`Brown.Cocke.Della-Pietra.ea.1988,Brown.Cocke.Della-Pietra.ea.1990`\n",
    "before the rise\n",
    "of\n",
    "end-to-end learning using\n",
    "neural networks.\n",
    "The latter\n",
    "is often called\n",
    "*neural machine translation*\n",
    "to distinguish itself from\n",
    "*statistical machine translation*\n",
    "that involves statistical analysis\n",
    "in components such as\n",
    "the translation model and the language model.\n",
    "\n",
    "\n",
    "Emphasizing end-to-end learning,\n",
    "this book will focus on neural machine translation methods.\n",
    "Different from our language model problem\n",
    "in :numref:`sec_language_model`\n",
    "whose corpus is in one single language,\n",
    "machine translation datasets\n",
    "are composed of pairs of text sequences\n",
    "that are in\n",
    "the source language and the target language, respectively.\n",
    "Thus,\n",
    "instead of reusing the preprocessing routine\n",
    "for language modeling,\n",
    "we need a different way to preprocess\n",
    "machine translation datasets.\n",
    "In the following,\n",
    "we show how to\n",
    "load the preprocessed data\n",
    "into minibatches for training.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load ../utils/djl-imports\n",
    "%load ../utils/plot-utils\n",
    "%load ../utils/Functions.java\n",
    "\n",
    "%load ../utils/timemachine/Vocab.java\n",
    "%load ../utils/timemachine/RNNModel.java\n",
    "%load ../utils/timemachine/RNNModelScratch.java\n",
    "%load ../utils/timemachine/TimeMachine.java\n",
    "%load ../utils/timemachine/TimeMachineDataset.java"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import java.nio.charset.*;\n",
    "import java.util.zip.*;\n",
    "import java.util.stream.*;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NDManager manager = NDManager.newBaseManager();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 4
   },
   "source": [
    "## Downloading and Preprocessing the Dataset\n",
    "\n",
    "To begin with,\n",
    "we download an English-French dataset\n",
    "that consists of [bilingual sentence pairs from the Tatoeba Project](http://www.manythings.org/anki/).\n",
    "Each line in the dataset\n",
    "is a tab-delimited pair\n",
    "of an English text sequence\n",
    "and the translated French text sequence.\n",
    "Note that each text sequence\n",
    "can be just one sentence or a paragraph of multiple sentences.\n",
    "In this machine translation problem\n",
    "where English is translated into French,\n",
    "English is the *source language*\n",
    "and French is the *target language*.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public static String readDataNMT() throws IOException {\n",
    "    DownloadUtils.download(\n",
    "            \"http://d2l-data.s3-accelerate.amazonaws.com/fra-eng.zip\", \"fra-eng.zip\");\n",
    "    ZipFile zipFile = new ZipFile(new File(\"fra-eng.zip\"));\n",
    "    Enumeration<? extends ZipEntry> entries = zipFile.entries();\n",
    "    while (entries.hasMoreElements()) {\n",
    "        ZipEntry entry = entries.nextElement();\n",
    "        if (entry.getName().contains(\"fra.txt\")) {\n",
    "            InputStream stream = zipFile.getInputStream(entry);\n",
    "            return new String(stream.readAllBytes(), StandardCharsets.UTF_8);\n",
    "        }\n",
    "    }\n",
    "    return null;\n",
    "}\n",
    "\n",
    "String rawText = readDataNMT();\n",
    "System.out.println(rawText.substring(0, 75));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 6
   },
   "source": [
    "After downloading the dataset,\n",
    "we proceed with several preprocessing steps\n",
    "for the raw text data.\n",
    "For instance,\n",
    "we replace non-breaking space with space,\n",
    "convert uppercase letters to lowercase ones,\n",
    "and insert space between words and punctuation marks.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public static String preprocessNMT(String text) {\n",
    "    // Replace non-breaking space with space, and convert uppercase letters to\n",
    "    // lowercase ones\n",
    "\n",
    "    text = text.replace('\\u202f', ' ').replaceAll(\"\\\\xa0\", \" \").toLowerCase();\n",
    "\n",
    "    // Insert space between words and punctuation marks\n",
    "    StringBuilder out = new StringBuilder();\n",
    "    Character currChar;\n",
    "    for (int i = 0; i < text.length(); i++) {\n",
    "        currChar = text.charAt(i);\n",
    "        if (i > 0 && noSpace(currChar, text.charAt(i - 1))) {\n",
    "            out.append(' ');\n",
    "        }\n",
    "        out.append(currChar);\n",
    "    }\n",
    "    return out.toString();\n",
    "}\n",
    "\n",
    "public static boolean noSpace(Character currChar, Character prevChar) {\n",
    "    /* Preprocess the English-French dataset. */\n",
    "    return new HashSet<>(Arrays.asList(',', '.', '!', '?')).contains(currChar)\n",
    "            && prevChar != ' ';\n",
    "}\n",
    "\n",
    "String text = preprocessNMT(rawText);\n",
    "System.out.println(text.substring(0, 80));"
   ]
  },
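  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see these steps in isolation,\n",
    "the following sketch applies `preprocessNMT` (defined above)\n",
    "to a tiny hand-written string.\n",
    "The `sample` text is a made-up illustration rather than part of the dataset:\n",
    "it contains uppercase letters, a non-breaking space (U+00A0),\n",
    "and punctuation marks attached directly to words.\n",
    "We would expect the output to be lowercased,\n",
    "with each punctuation mark separated from the preceding word by a plain space.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Hypothetical sample (not taken from the dataset): two tab-delimited pairs\n",
    "String sample = \"Go.\\tVa !\\nHi there!\\tSalut\\u00a0!\";\n",
    "String cleaned = preprocessNMT(sample);\n",
    "System.out.println(cleaned);"
   ]
  },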
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 8
   },
   "source": [
    "## Tokenization\n",
    "\n",
    "Different from character-level tokenization\n",
    "in :numref:`sec_language_model`,\n",
    "for machine translation\n",
    "we prefer word-level tokenization here\n",
    "(state-of-the-art models may use more advanced tokenization techniques).\n",
    "The following `tokenizeNMT` function\n",
    "tokenizes the the first `numExamples` text sequence pairs,\n",
    "where\n",
    "each token is either a word or a punctuation mark.\n",
    "This function returns\n",
    "two lists of token lists: `source` and `target`.\n",
    "Specifically,\n",
    "`source.get(i)` is a list of tokens from the\n",
    "$i^\\mathrm{th}$ text sequence in the source language (English here) and `target.get(i)` is that in the target language (French here).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public static Pair<ArrayList<String[]>, ArrayList<String[]>> tokenizeNMT(\n",
    "        String text, Integer numExamples) {\n",
    "    ArrayList<String[]> source = new ArrayList<>();\n",
    "    ArrayList<String[]> target = new ArrayList<>();\n",
    "\n",
    "    int i = 0;\n",
    "    for (String line : text.split(\"\\n\")) {\n",
    "        if (numExamples != null && i > numExamples) {\n",
    "            break;\n",
    "        }\n",
    "        String[] parts = line.split(\"\\t\");\n",
    "        if (parts.length == 2) {\n",
    "            source.add(parts[0].split(\" \"));\n",
    "            target.add(parts[1].split(\" \"));\n",
    "        }\n",
    "        i += 1;\n",
    "    }\n",
    "    return new Pair<>(source, target);\n",
    "}\n",
    "\n",
    "Pair<ArrayList<String[]>, ArrayList<String[]>> pair = tokenizeNMT(text.toString(), null);\n",
    "ArrayList<String[]> source = pair.getKey();\n",
    "ArrayList<String[]> target = pair.getValue();\n",
    "for (String[] subArr : source.subList(0, 6)) {\n",
    "    System.out.println(Arrays.toString(subArr));\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for (String[] subArr : target.subList(0, 6)) {\n",
    "    System.out.println(Arrays.toString(subArr));\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 10
   },
   "source": [
    "Let us plot the histogram of the number of tokens per text sequence.\n",
    "In this simple English-French dataset,\n",
    "most of the text sequences have fewer than 20 tokens.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "double[] y1 = new double[source.size()];\n",
    "for (int i = 0; i < source.size(); i++) y1[i] = source.get(i).length;\n",
    "double[] y2 = new double[target.size()];\n",
    "for (int i = 0; i < target.size(); i++) y2[i] = target.get(i).length;\n",
    "\n",
    "HistogramTrace trace1 =\n",
    "        HistogramTrace.builder(y1).opacity(.75).name(\"source\").nBinsX(20).build();\n",
    "HistogramTrace trace2 =\n",
    "        HistogramTrace.builder(y2).opacity(.75).name(\"target\").nBinsX(20).build();\n",
    "\n",
    "Layout layout = Layout.builder().barMode(Layout.BarMode.GROUP).build();\n",
    "new Figure(layout, trace1, trace2);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 12
   },
   "source": [
    "## Vocabulary\n",
    "\n",
    "Since the machine translation dataset\n",
    "consists of pairs of languages,\n",
    "we can build two vocabularies for\n",
    "both the source language and\n",
    "the target language separately.\n",
    "With word-level tokenization,\n",
    "the vocabulary size will be significantly larger\n",
    "than that using character-level tokenization.\n",
    "To alleviate this,\n",
    "here we treat infrequent tokens\n",
    "that appear less than 2 times\n",
    "as the same unknown (\"&lt;unk&gt;\") token.\n",
    "Besides that,\n",
    "we specify additional special tokens\n",
    "such as for padding (\"&lt;pad&gt;\") sequences to the same length in minibatches,\n",
    "and for marking the beginning (\"&lt;bos&gt;\") or end (\"&lt;eos&gt;\") of sequences.\n",
    "Such special tokens are commonly used in\n",
    "natural language processing tasks.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Vocab srcVocab =\n",
    "                new Vocab(\n",
    "                        source.stream().toArray(String[][]::new),\n",
    "                        2,\n",
    "                        new String[] {\"<pad>\", \"<bos>\", \"<eos>\"});\n",
    "System.out.println(srcVocab.length());"
   ]
  },
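  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell above builds only the source vocabulary.\n",
    "As a sketch, the same constructor call can be applied to the `target` token lists\n",
    "to obtain a second vocabulary, so that the two sizes can be compared.\n",
    "The name `tgtVocabDemo` is introduced here purely for illustration,\n",
    "and the only `Vocab` methods used are those that already appear in this section\n",
    "(`length`, `getIdx`).\n",
    "The loop at the end prints the indices assigned to the reserved tokens.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// tgtVocabDemo is an illustrative name; the arguments mirror those used for srcVocab\n",
    "Vocab tgtVocabDemo =\n",
    "        new Vocab(\n",
    "                target.stream().toArray(String[][]::new),\n",
    "                2,\n",
    "                new String[] {\"<pad>\", \"<bos>\", \"<eos>\"});\n",
    "System.out.println(\"source vocabulary size: \" + srcVocab.length());\n",
    "System.out.println(\"target vocabulary size: \" + tgtVocabDemo.length());\n",
    "\n",
    "// Indices assigned to the reserved tokens in the source vocabulary\n",
    "for (String token : new String[] {\"<pad>\", \"<bos>\", \"<eos>\"}) {\n",
    "    System.out.println(token + \" -> \" + srcVocab.getIdx(token));\n",
    "}"
   ]
  },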
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 14
   },
   "source": [
    "## Loading the Dataset\n",
    ":label:`subsec_mt_data_loading`\n",
    "\n",
    "Recall that in language modeling\n",
    "each sequence example,\n",
    "either a segment of one sentence\n",
    "or a span over multiple sentences,\n",
    "has a fixed length.\n",
    "This was specified by the `numSteps`\n",
    "(number of time steps or tokens) argument in :numref:`sec_language_model`.\n",
    "In machine translation, each example is\n",
    "a pair of source and target text sequences,\n",
    "where each text sequence may have different lengths.\n",
    "\n",
    "For computational efficiency,\n",
    "we can still process a minibatch of text sequences\n",
    "at one time by *truncation* and *padding*.\n",
    "Suppose that every sequence in the same minibatch\n",
    "should have the same length `numSteps`.\n",
    "If a text sequence has fewer than `numSteps` tokens,\n",
    "we will keep appending the special \"&lt;pad&gt;\" token\n",
    "to its end until its length reaches `numSteps`.\n",
    "Otherwise,\n",
    "we will truncate the text sequence\n",
    "by only taking its first `numSteps` tokens\n",
    "and discarding the remaining.\n",
    "In this way,\n",
    "every text sequence\n",
    "will have the same length\n",
    "to be loaded in minibatches of the same shape.\n",
    "\n",
    "The following `truncatePad` function\n",
    "truncates or pads text sequences as described before.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public static int[] truncatePad(Integer[] integerLine, int numSteps, int paddingToken) {\n",
    "    /* Truncate or pad sequences */\n",
    "    int[] line = Arrays.stream(integerLine).mapToInt(i -> i).toArray();\n",
    "    if (line.length > numSteps) {\n",
    "        return Arrays.copyOfRange(line, 0, numSteps);\n",
    "    }\n",
    "    int[] paddingTokenArr = new int[numSteps - line.length]; // Pad\n",
    "    Arrays.fill(paddingTokenArr, paddingToken);\n",
    "\n",
    "    return IntStream.concat(Arrays.stream(line), Arrays.stream(paddingTokenArr)).toArray();\n",
    "}\n",
    "\n",
    "int[] result = truncatePad(srcVocab.getIdxs(source.get(0)), 10, srcVocab.getIdx(\"<pad>\"));\n",
    "System.out.println(Arrays.toString(result));"
   ]
  },
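  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The example above only exercises the padding branch.\n",
    "The following sketch calls `truncatePad` (defined above) on a short, made-up\n",
    "array of token indices to show both branches side by side.\n",
    "The values in `tokens` and the use of `0` as the padding index\n",
    "are assumptions chosen for illustration, not indices from the real vocabulary.\n",
    "With `numSteps = 3` the five-token sequence should be cut to its first three tokens,\n",
    "and with `numSteps = 8` it should be extended with three padding indices.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Hypothetical token indices; 0 stands in for the padding index here\n",
    "Integer[] tokens = {5, 6, 7, 8, 9};\n",
    "int[] truncated = truncatePad(tokens, 3, 0); // longer than numSteps: keep the first 3\n",
    "int[] padded = truncatePad(tokens, 8, 0); // shorter than numSteps: append three 0s\n",
    "System.out.println(Arrays.toString(truncated));\n",
    "System.out.println(Arrays.toString(padded));"
   ]
  },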
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 16
   },
   "source": [
    "Now we define a function to transform\n",
    "text sequences into minibatches for training.\n",
    "We append the special “&lt;eos&gt;” token\n",
    "to the end of every sequence to indicate the\n",
    "end of the sequence.\n",
    "When a model is predicting\n",
    "by\n",
    "generating a sequence token after token,\n",
    "the generation\n",
    "of the “&lt;eos&gt;” token\n",
    "can suggest that\n",
    "the output sequence is complete.\n",
    "Besides,\n",
    "we also record the length\n",
    "of each text sequence excluding the padding tokens.\n",
    "This information will be needed by\n",
    "some models that\n",
    "we will cover later.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public static Pair<NDArray, NDArray> buildArrayNMT(\n",
    "        List<String[]> lines, Vocab vocab, int numSteps) {\n",
    "    /* Transform text sequences of machine translation into minibatches. */\n",
    "    List<Integer[]> linesIntArr = new ArrayList<>();\n",
    "    for (String[] strings : lines) {\n",
    "        linesIntArr.add(vocab.getIdxs(strings));\n",
    "    }\n",
    "    for (int i = 0; i < linesIntArr.size(); i++) {\n",
    "        List<Integer> temp = new ArrayList<>(Arrays.asList(linesIntArr.get(i)));\n",
    "        temp.add(vocab.getIdx(\"<eos>\"));\n",
    "        linesIntArr.set(i, temp.toArray(new Integer[0]));\n",
    "    }\n",
    "\n",
    "    NDManager manager = NDManager.newBaseManager();\n",
    "\n",
    "    NDArray arr = manager.create(new Shape(linesIntArr.size(), numSteps), DataType.INT32);\n",
    "    int row = 0;\n",
    "    for (Integer[] line : linesIntArr) {\n",
    "        NDArray rowArr = manager.create(truncatePad(line, numSteps, vocab.getIdx(\"<pad>\")));\n",
    "        arr.set(new NDIndex(\"{}:\", row), rowArr);\n",
    "        row += 1;\n",
    "    }\n",
    "    NDArray validLen = arr.neq(vocab.getIdx(\"<pad>\")).sum(new int[] {1});\n",
    "    return new Pair<>(arr, validLen);\n",
    "}"
   ]
  },
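  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick check, the following sketch runs `buildArrayNMT` on the first three\n",
    "source sequences with `numSteps = 6`.\n",
    "It assumes that `source` and `srcVocab` from the earlier cells are still in scope,\n",
    "and `demo` is a name introduced here for illustration.\n",
    "The first printed array should have one row per sequence and six columns,\n",
    "and the second should hold, for each row, the number of tokens that are not padding\n",
    "(including the appended end-of-sequence token).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Build a small array from the first three source sequences, with numSteps = 6\n",
    "Pair<NDArray, NDArray> demo = buildArrayNMT(source.subList(0, 3), srcVocab, 6);\n",
    "System.out.println(demo.getKey()); // token indices, one row per sequence\n",
    "System.out.println(demo.getValue()); // number of non-padding tokens per row"
   ]
  },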
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 18
   },
   "source": [
    "## Putting All Things Together\n",
    "\n",
    "Finally, we define the `loadDataNMT` function\n",
    "to return the data iterator, together with\n",
    "the vocabularies for both the source language and the target language.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public static Pair<ArrayDataset, Pair<Vocab, Vocab>> loadDataNMT(\n",
    "        int batchSize, int numSteps, int numExamples) throws IOException {\n",
    "    /* Return the iterator and the vocabularies of the translation dataset. */\n",
    "    String text = preprocessNMT(readDataNMT());\n",
    "    Pair<ArrayList<String[]>, ArrayList<String[]>> pair = tokenizeNMT(text, numExamples);\n",
    "    ArrayList<String[]> source = pair.getKey();\n",
    "    ArrayList<String[]> target = pair.getValue();\n",
    "    Vocab srcVocab =\n",
    "            new Vocab(\n",
    "                    source.toArray(String[][]::new),\n",
    "                    2,\n",
    "                    new String[] {\"<pad>\", \"<bos>\", \"<eos>\"});\n",
    "    Vocab tgtVocab =\n",
    "            new Vocab(\n",
    "                    target.toArray(String[][]::new),\n",
    "                    2,\n",
    "                    new String[] {\"<pad>\", \"<bos>\", \"<eos>\"});\n",
    "\n",
    "    Pair<NDArray, NDArray> pairArr = buildArrayNMT(source, srcVocab, numSteps);\n",
    "    NDArray srcArr = pairArr.getKey();\n",
    "    NDArray srcValidLen = pairArr.getValue();\n",
    "\n",
    "    pairArr = buildArrayNMT(target, tgtVocab, numSteps);\n",
    "    NDArray tgtArr = pairArr.getKey();\n",
    "    NDArray tgtValidLen = pairArr.getValue();\n",
    "\n",
    "    ArrayDataset dataset =\n",
    "            new ArrayDataset.Builder()\n",
    "                    .setData(srcArr, srcValidLen)\n",
    "                    .optLabels(tgtArr, tgtValidLen)\n",
    "                    .setSampling(batchSize, true)\n",
    "                    .build();\n",
    "\n",
    "    return new Pair<>(dataset, new Pair<>(srcVocab, tgtVocab));\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 20
   },
   "source": [
    "Let us read the first minibatch from the English-French dataset.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Pair<ArrayDataset, Pair<Vocab, Vocab>> output = loadDataNMT(2, 8, 600);\n",
    "ArrayDataset dataset = output.getKey();\n",
    "srcVocab = output.getValue().getKey();\n",
    "Vocab tgtVocab = output.getValue().getValue();\n",
    "\n",
    "Batch batch = dataset.getData(manager).iterator().next();\n",
    "NDArray X = batch.getData().get(0);\n",
    "NDArray xValidLen = batch.getData().get(1);\n",
    "NDArray Y = batch.getLabels().get(0);\n",
    "NDArray yValidLen = batch.getLabels().get(1);\n",
    "System.out.println(X);\n",
    "System.out.println(xValidLen);\n",
    "System.out.println(Y);\n",
    "System.out.println(yValidLen);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 22
   },
   "source": [
    "## Summary\n",
    "\n",
    "* Machine translation refers to the automatic translation of a sequence from one language to another.\n",
    "* Using word-level tokenization, the vocabulary size will be significantly larger than that using character-level tokenization. To alleviate this, we can treat infrequent tokens as the same unknown token.\n",
    "* We can truncate and pad text sequences so that all of them will have the same length to be loaded in minibatches.\n",
    "\n",
    "\n",
    "## Exercises\n",
    "\n",
    "1. Try different values of the `numExamples` argument in the `loadDataNMT` function. How does this affect the vocabulary sizes of the source language and the target language?\n",
    "1. Text in some languages such as Chinese and Japanese does not have word boundary indicators (e.g., space). Is word-level tokenization still a good idea for such cases? Why or why not?\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "14.0.2+12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
